
Implementation of flash attention for native webgpu ep #22932

Open
wants to merge 25 commits into main

Conversation


@sushraja-msft commented Nov 24, 2024

Description

This change implements flash attention in WebGPU to improve prefill speed.
Perf numbers from an Intel Alder Lake device:

Baseline MHA

Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       2.26746e+07
        avg (tokens/s): 22.0952              <<<
        p50 (us):       2.34637e+07
        stddev (us):    3.92912e+06
        n:              5 * 501 token(s)
Token generation:
        avg (us):       96519.8
        avg (tokens/s): 10.3606              <<<
        p50 (us):       98061.5
        stddev (us):    9220.87
        n:              635 * 1 token(s)

With FA

Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.69236e+07
        avg (tokens/s): 29.6036             <<<
        p50 (us):       1.63162e+07
        stddev (us):    960417
        n:              5 * 501 token(s)
Token generation:
        avg (us):       91436.7
        avg (tokens/s): 10.9365             <<<
        p50 (us):       90397.1
        stddev (us):    5349.19
        n:              635 * 1 token(s)

Motivation and Context

On integrated GPUs, memory bandwidth is at a premium. Flash attention turns the softmax computation (and therefore the output attention vector computation) into a running operation, instead of materializing the full QK^T attention score matrix in memory. As a result, we see a significant improvement in prefill speed - a 30% speedup measured here.
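
To illustrate the idea, below is a minimal scalar sketch of the running (online) softmax that flash attention relies on. It is host-side C++ pseudocode rather than the actual WGSL shader in this PR, and the function and variable names (AttendOneQueryRow, score_tiles, v_tiles) are hypothetical; it only shows how the running max and normalizer are rescaled per K/V tile so the full row of attention scores never has to be stored.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical sketch of the online softmax update used by flash attention.
// For one query row, K/V are processed one tile at a time; the running max (m)
// and normalizer (l) are rescaled so output accumulated from earlier tiles
// stays correct without ever storing the full row of QK^T scores.
void AttendOneQueryRow(const std::vector<std::vector<float>>& score_tiles,           // QK^T scores, one vector per tile
                       const std::vector<std::vector<std::vector<float>>>& v_tiles,  // V rows matching each tile
                       std::vector<float>& out) {                                    // output vector of size head_size
  float m = -std::numeric_limits<float>::infinity();  // running max of scores seen so far
  float l = 0.0f;                                     // running softmax denominator
  std::fill(out.begin(), out.end(), 0.0f);

  for (size_t t = 0; t < score_tiles.size(); ++t) {
    const std::vector<float>& scores = score_tiles[t];
    float tile_max = *std::max_element(scores.begin(), scores.end());
    float m_new = std::max(m, tile_max);
    float correction = std::exp(m - m_new);  // rescales everything accumulated under the old max

    l *= correction;
    for (float& o : out) o *= correction;

    for (size_t j = 0; j < scores.size(); ++j) {
      float p = std::exp(scores[j] - m_new);
      l += p;
      for (size_t d = 0; d < out.size(); ++d) {
        out[d] += p * v_tiles[t][j][d];
      }
    }
    m = m_new;
  }
  for (float& o : out) o /= l;  // final normalization
}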

This implementation also uses the new WebGPU subgroups feature to further accelerate the attention computation.

  • Tested on Intel Alderlake (Subgroup Size 16) with Phi 3.5 mini.
  • Tested on Nvidia 2060 (Subgroup Size 32) with Phi 3.5 mini.
  • Tested with Llama 3.2 1B parameters; FlashAttention does not activate because past/present keys are always null. Needs investigation into the model to understand why this is the case.

Remaining work

  • Algorithm specialization for the generation phase: here the shared-memory tiles for K/V can be removed, because each K/V value is used only once, freeing shared memory for a larger tile size.
  • Algorithm specialization for the no-past-KV case (prefill). The CopyKVCache operation can likely be eliminated here: since there are no past KV values to copy over, the new KV values can be written into present KV as part of flash attention. PIX profiling shows CopyKVCache is almost as expensive as the FlashAttention shader itself. A static KV cache would also eliminate this copy and bring further performance wins.

How to enable

Currently flash attention is off by default. To enable it, add

"provider_options": [
  {
    "webgpu": { "enableFlashAttention": "1" }
  }
]

to genai_config.json.
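
For context, this block sits under the decoder session options in genai_config.json; the surrounding fields below are a sketch of a typical onnxruntime-genai config, not taken from this PR:

{
  "model": {
    "decoder": {
      "session_options": {
        "provider_options": [
          {
            "webgpu": { "enableFlashAttention": "1" }
          }
        ]
      }
    }
  }
}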

@guschmue
Contributor

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline

@guschmue
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@guschmue
Contributor

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

Azure Pipelines successfully started running 2 pipeline(s).

@guschmue
Contributor

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).

Azure Pipelines successfully started running 3 pipeline(s).

Azure Pipelines successfully started running 9 pipeline(s).

@guschmue
Contributor

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline

@guschmue
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@guschmue
Contributor

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

Azure Pipelines successfully started running 2 pipeline(s).

@guschmue
Contributor

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).

Azure Pipelines successfully started running 3 pipeline(s).

Azure Pipelines successfully started running 9 pipeline(s).

@guschmue
Contributor

Very cool, I can give it a test drive on some other GPUs and macOS.

@guschmue
Contributor

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline

@guschmue
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@guschmue
Contributor

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

Azure Pipelines successfully started running 2 pipeline(s).

@guschmue
Contributor

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 3 pipeline(s).

Azure Pipelines successfully started running 4 pipeline(s).

Azure Pipelines successfully started running 9 pipeline(s).

3 participants